{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Langfuse Data Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from datetime import datetime, timedelta\n",
    "from typing import Optional\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from itables import init_notebook_mode, show\n",
    "from langfuse import Langfuse\n",
    "from langfuse.api.resources.commons.types.observations_view import ObservationsView\n",
    "from pydantic import BaseModel\n",
    "import pandas as pd\n",
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utilities Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_langfuse_client(public_key: str, secret_key: str, host: str):\n",
    "    return Langfuse(\n",
    "        public_key=public_key,\n",
    "        secret_key=secret_key,\n",
    "        host=host,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_traces(client: Langfuse, name: Optional[str]=None, from_timestamp: Optional[datetime]=None, to_timestamp: Optional[datetime]=None):\n",
    "    traces = []\n",
    "    page = 1\n",
    "\n",
    "    while True:\n",
    "        data = client.fetch_traces(name=name, page=page, from_timestamp=from_timestamp, to_timestamp=to_timestamp).data\n",
    "        if len(data) == 0:\n",
    "            break\n",
    "        traces += data\n",
    "        page += 1\n",
    "\n",
    "    return traces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_observations(\n",
    "    client: Langfuse, \n",
    "    name: Optional[str]=None, \n",
    "    from_timestamp: Optional[datetime]=None, \n",
    "    to_timestamp: Optional[datetime]=None\n",
    ") -> list[ObservationsView]:\n",
    "    observations = []\n",
    "    page = 1\n",
    "\n",
    "    while True:\n",
    "        data = client.fetch_observations(name=name, page=page, from_start_time=from_timestamp, to_start_time=to_timestamp).data\n",
    "        if len(data) == 0:\n",
    "            break\n",
    "        observations += data\n",
    "        page += 1\n",
    "\n",
    "    return observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pprint_json(data):\n",
    "    print(json.dumps(json.loads(data), indent=2, ensure_ascii=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_error_results_by_query(error_results, query):\n",
    "    results = []\n",
    "    for error_type in error_results.keys():\n",
    "        results += list(\n",
    "            filter(\n",
    "                lambda error_result: error_result.dict()['input']['args'][0]['query'] == query,\n",
    "                error_results[error_type]\n",
    "            )\n",
    "        )\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_error_result_details(error_results):\n",
    "    error_results_details = {}\n",
    "    for result in error_results:\n",
    "        error_type = result.output['metadata']['error_type']\n",
    "        if error_type not in error_results_details:\n",
    "            error_results_details[error_type] = [result]\n",
    "        else:\n",
    "            error_results_details[error_type].append(result)\n",
    "    return error_results_details"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_traces_by_conditions(traces, conditions):\n",
    "    def get_traces_with_some_conditions(traces, conditions):\n",
    "        results = []\n",
    "        for trace in traces:\n",
    "            match = True\n",
    "            for key, value in conditions.items():\n",
    "                if key == \"metadata\":\n",
    "                    for meta_key, meta_value in value.items():\n",
    "                        if trace.metadata.get(meta_key) != meta_value:\n",
    "                            match = False\n",
    "                            break\n",
    "                elif getattr(trace, key, None) != value:\n",
    "                    match = False\n",
    "                    break\n",
    "            if match:\n",
    "                results.append(trace)\n",
    "        return results\n",
    "    \n",
    "    def get_trace_results_by_type(traces):\n",
    "        error_results = []\n",
    "        no_error_results = []\n",
    "        for trace in traces:\n",
    "            if trace.metadata.get('error_type', ''):\n",
    "                error_results.append(trace)\n",
    "            else:\n",
    "                no_error_results.append(trace)\n",
    "\n",
    "        assert len(error_results) + len(no_error_results) == len(traces)\n",
    "\n",
    "        return error_results, no_error_results\n",
    "\n",
    "    _traces = get_traces_with_some_conditions(traces, conditions)\n",
    "    print(f'number of traces: {len(_traces)}')\n",
    "\n",
    "    error_results, no_error_results = get_trace_results_by_type(_traces)\n",
    "    print(f'# of error results: {len(error_results)}')\n",
    "    print(f'# of no error results: {len(no_error_results)}')\n",
    "    print(f'ratio of failed traces: {len(error_results) / len(_traces)}')\n",
    "\n",
    "    return error_results, no_error_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_traces_group_by_value(traces, name, value):\n",
    "    results = {}\n",
    "    for trace in traces:\n",
    "        if trace.name == name:\n",
    "            if _val := trace.metadata.get(value, ''):\n",
    "                if _val not in results:\n",
    "                    results[_val] = [trace]\n",
    "                else:\n",
    "                    results[_val].append(trace)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_notebook_mode(all_interactive=True)\n",
    "load_dotenv(\".env\", override=True)\n",
    "\n",
    "client = init_langfuse_client(\n",
    "    os.getenv(\"LANGFUSE_PUBLIC_KEY\"),\n",
    "    os.getenv(\"LANGFUSE_SECRET_KEY\"),\n",
    "    os.getenv(\"LANGFUSE_HOST\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all traces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "traces = get_all_traces(client)\n",
    "len(traces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all Intent Classification observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IntentClassificationData(BaseModel):\n",
    "    class Input(BaseModel):\n",
    "        query: str\n",
    "        project_id: Optional[str] = None\n",
    "        history: Optional[dict] = None\n",
    "\n",
    "\n",
    "    class Output(BaseModel):\n",
    "        intent: str\n",
    "        db_schemas: list[str]\n",
    "\n",
    "    id: str\n",
    "    project_id: str\n",
    "    trace_id: str\n",
    "    start_time: datetime\n",
    "    end_time: datetime\n",
    "    input: Input\n",
    "    output: Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "observations = get_all_observations(client, name=\"Intent Classification\")\n",
    "len(observations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_observations_by_intent(observations: list[IntentClassificationData]):\n",
    "    intents = set([observation.output.intent for observation in observations])\n",
    "    results = {}\n",
    "    for intent in intents:\n",
    "        results[intent] = [observation for observation in observations if observation.output.intent == intent]\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "observation_results = []\n",
    "for observation in observations:\n",
    "    observation_results.append(IntentClassificationData(\n",
    "        trace_id=observation.trace_id,\n",
    "        start_time=observation.start_time,\n",
    "        end_time=observation.end_time,\n",
    "        project_id=observation.projectId,\n",
    "        id=observation.id,\n",
    "        input=IntentClassificationData.Input(\n",
    "            query=observation.input['kwargs']['query'],\n",
    "            project_id=observation.input['kwargs']['id'],\n",
    "            history=observation.input['kwargs']['history'],\n",
    "        ),\n",
    "        output=IntentClassificationData.Output(\n",
    "            intent=observation.output['post_process']['intent'],\n",
    "            db_schemas=observation.output['post_process']['db_schemas'],\n",
    "        ),\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pprint_json(observation_results[0].model_dump_json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for intent, observations in group_observations_by_intent(observation_results).items():\n",
    "    print(f'intent: {intent}')\n",
    "    print(f'number of observations: {len(observations)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _, observations in group_observations_by_intent(observation_results).items():\n",
    "    for observation in observations:\n",
    "        query = observation.input.query\n",
    "        if observation.input.history:\n",
    "            summaries = [step['summary'] for step in observation.input.history.get('steps', []) if step.get('summary', '')]\n",
    "            query = ' '.join(summaries) + ' ' + observation.input.query\n",
    "        print(f'project_id: {observation.input.project_id}')\n",
    "        print(f'url: {os.getenv('LANGFUSE_HOST')}/project/{observation.project_id}/traces/{observation.trace_id}?observation={observation.id}')\n",
    "        print(f'query: {query}')\n",
    "        print(f'intent: {observation.output.intent}')\n",
    "        print('-' * 100)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trace names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set(trace.name for trace in traces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "RELEASE = \"0.8.17\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Traces: Prepare Semantics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mdl_table_distribution(traces):\n",
    "    results = []\n",
    "    for result in traces:\n",
    "        if isinstance(result.input, dict):\n",
    "            results.append(\n",
    "                len(json.loads(result.input['args'][0]['mdl'])['models'])\n",
    "            )\n",
    "\n",
    "    df = pd.DataFrame(results, columns=['table_num'])\n",
    "    fig = px.histogram(df, x=\"table_num\")\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mdl_columns_per_table_distribution(traces):\n",
    "    results = []\n",
    "    for result in traces:\n",
    "        if isinstance(result.input, dict):\n",
    "            mdl = json.loads(result.input['args'][0]['mdl'])\n",
    "            for model in mdl['models']:\n",
    "                results.append(len(model['columns']))\n",
    "\n",
    "    df = pd.DataFrame(results, columns=['column_num'])\n",
    "    fig = px.histogram(df, x=\"column_num\")\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conditions = {\n",
    "    \"metadata\": {\n",
    "    },\n",
    "    \"name\": \"Prepare Semantics\",\n",
    "    # \"release\": RELEASE,\n",
    "}\n",
    "error_results, no_error_results = get_traces_by_conditions(traces, conditions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set([model.metadata['mdl_hash'] for model in error_results + no_error_results])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_mdl_table_distribution(error_results + no_error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_mdl_columns_per_table_distribution(error_results + no_error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_results_details = get_error_result_details(error_results)\n",
    "\n",
    "for key, value in error_results_details.items():\n",
    "    print(key)\n",
    "    print(len(value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Traces: Ask Question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dry_run_failed_cases(client, from_timestamp = None, to_timestamp = None):\n",
    "    results = []\n",
    "    for data in get_all_observations(client, name=\"post_process\", from_timestamp=from_timestamp, to_timestamp=to_timestamp):\n",
    "        if invalid_generation_results := data.output.get('invalid_generation_results', []):\n",
    "            for result in invalid_generation_results:\n",
    "                result['trace_id'] = data.trace_id\n",
    "            results += invalid_generation_results\n",
    "    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conditions = {\n",
    "    \"metadata\": {\n",
    "    },\n",
    "    \"name\": \"Ask Question\",\n",
    "    # \"release\": RELEASE,\n",
    "}\n",
    "error_results, no_error_results = get_traces_by_conditions(traces, conditions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "get number of ask question traces grouped by mdl_hash"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_traces_by_mdl_hash = get_traces_group_by_value(error_results, \"Ask Question\", \"mdl_hash\")\n",
    "sorted_traces_by_mdl_hash = sorted(_traces_by_mdl_hash.items(), key=lambda x: len(x[1]), reverse=True)\n",
    "\n",
    "print(f'number of mdl_hash: {len(sorted_traces_by_mdl_hash)}')\n",
    "for mdl_hash, traces in sorted_traces_by_mdl_hash:\n",
    "    print(f'mdl_hash: {mdl_hash}')\n",
    "    print(f'size of traces: {len(traces)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "show ask question failed reaons groupbed by error_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dry_run_failed_cases = get_dry_run_failed_cases(client, from_timestamp=datetime.now() - timedelta(days=14))\n",
    "len(dry_run_failed_cases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show(\n",
    "    pd.DataFrame(dry_run_failed_cases),\n",
    "    buttons=[\"csvHtml5\"],\n",
    "    layout={\"top1\": \"searchPanes\"},\n",
    "    searchPanes={\"layout\": \"columns-3\", \"cascadePanes\": True}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_results_details = get_error_result_details(error_results)\n",
    "for key, value in error_results_details.items():\n",
    "    print(key)\n",
    "    print(len(value))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in error_results_details.keys():\n",
    "    print(f'Error Type: {key}')\n",
    "    for error_result in error_results_details[key]:\n",
    "        pprint_json(error_result.json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_error_results = get_error_results_by_query(error_results_details, '我在台中公園，有哪些路線我可以搭乘？')\n",
    "len(_error_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _error_result in _error_results:\n",
    "    pprint_json(_error_result.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Traces: Ask Details(Breakdown SQL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conditions = {\n",
    "    \"metadata\": {\n",
    "    },\n",
    "    \"name\": \"Ask Details(Breakdown SQL)\",\n",
    "    \"release\": RELEASE,\n",
    "}\n",
    "error_results, no_error_results = get_traces_by_conditions(traces, conditions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wren-ai-service-rIOQoSXj-py3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
